-
Notifications
You must be signed in to change notification settings - Fork 19.7k
Unify extract_patches to support both 2D and 3D patches #21980
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Unify extract_patches to support both 2D and 3D patches #21980
Conversation
Summary of ChangesHello @MarcosAsh, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly enhances the Highlights
🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console. Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request unifies the extract_patches function to support both 2D and 3D patches, which is a great step towards a more consistent API. The implementation looks good, and the added tests cover the new functionality well.
I have a few suggestions to further improve the user experience:
- The docstrings for
images,size, anddata_formatcould be more explicit to improve clarity, especially with the new 3D support. - I've also suggested adding more robust input validation at the beginning of the
extract_patchesfunction to provide clearer error messages for invalidsizearguments, along with the corresponding test updates. This aligns with Keras' API design guidelines for helpful error messages.
| images: Input image/volume or batch of images/volumes. | ||
| For 2D patches: 3D `(H, W, C)` or 4D `(N, H, W, C)`. | ||
| For 3D patches: 4D `(D, H, W, C)` or 5D `(N, D, H, W, C)`. | ||
| size: Patch size as int or tuple. | ||
| Length 2 tuple `(patch_height, patch_width)` for 2D patches. | ||
| Length 3 tuple `(patch_depth, patch_height, patch_width)` for | ||
| 3D patches. | ||
| strides: Strides for patch extraction. If not specified, defaults | ||
| to `size` (non-overlapping patches). | ||
| dilation_rate: Dilation rate for patch extraction. Note that | ||
| `dilation_rate > 1` is not supported with `strides > 1`. | ||
| padding: The type of padding algorithm to use: `"same"` or `"valid"`. | ||
| data_format: A string specifying the data format of the input tensor. | ||
| It can be either `"channels_last"` or `"channels_first"`. | ||
| `"channels_last"` corresponds to inputs with shape | ||
| `(batch, height, width, channels)`, while `"channels_first"` | ||
| corresponds to inputs with shape `(batch, channels, height, width)`. | ||
| If not specified, the value will default to | ||
| `keras.config.image_data_format`. | ||
| If not specified, defaults to `keras.config.image_data_format`. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The docstrings for images, size, and data_format could be more explicit to improve clarity, especially with the new 3D support. The current images docstring assumes channels_last without stating it, the size docstring is ambiguous about integer values, and the data_format docstring is a bit sparse. I suggest clarifying these points for a better user experience.
images: Input image/volume or batch of images/volumes. Assumes
`channels_last` data format.
For 2D patches: 3D `(H, W, C)` or 4D `(N, H, W, C)`.
For 3D patches: 4D `(D, H, W, C)` or 5D `(N, D, H, W, C)`.
size: The size of the patches to extract.
- If `size` is an integer, 2D patches of size `(size, size)` are
extracted.
- If `size` is a tuple of 2 integers, 2D patches of size
`(patch_height, patch_width)` are extracted.
- If `size` is a tuple of 3 integers, 3D patches of size
`(patch_depth, patch_height, patch_width)` are extracted.
strides: Strides for patch extraction. If not specified, defaults
to `size` (non-overlapping patches).
dilation_rate: Dilation rate for patch extraction. Note that
`dilation_rate > 1` is not supported with `strides > 1`.
padding: The type of padding algorithm to use: `"same"` or `"valid"`.
data_format: A string specifying the data format of the input tensor.
It can be either `"channels_last"` or `"channels_first"`.
`"channels_last"` corresponds to inputs with shape
`(batch, ..., channels)`, while `"channels_first"` corresponds
to inputs with shape `(batch, channels, ...)`. If not specified,
defaults to `keras.config.image_data_format`.| def test_extract_patches_invalid_size(self): | ||
| size = (3, 3, 3) # Invalid size, too many dimensions | ||
| size = "5" # Invalid size type | ||
| image = np.random.uniform(size=(2, 20, 20, 3)) | ||
| with self.assertRaisesRegex( | ||
| TypeError, "Expected an int or a tuple of length 2" | ||
| ): | ||
| kimage.extract_patches(image, size) | ||
|
|
||
| size = "5" # Invalid size type | ||
| size = (3, 3, 3, 3) # Invalid size, too many dimensions | ||
| with self.assertRaisesRegex( | ||
| TypeError, "Expected an int or a tuple of length 2" | ||
| ): | ||
| kimage.extract_patches(image, size) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This test should be updated to reflect the improved validation in extract_patches. The error types and messages for invalid size will change with the suggested validation logic.
| def test_extract_patches_invalid_size(self): | |
| size = (3, 3, 3) # Invalid size, too many dimensions | |
| size = "5" # Invalid size type | |
| image = np.random.uniform(size=(2, 20, 20, 3)) | |
| with self.assertRaisesRegex( | |
| TypeError, "Expected an int or a tuple of length 2" | |
| ): | |
| kimage.extract_patches(image, size) | |
| size = "5" # Invalid size type | |
| size = (3, 3, 3, 3) # Invalid size, too many dimensions | |
| with self.assertRaisesRegex( | |
| TypeError, "Expected an int or a tuple of length 2" | |
| ): | |
| kimage.extract_patches(image, size) | |
| def test_extract_patches_invalid_size(self): | |
| size = "5" # Invalid size type | |
| image = np.random.uniform(size=(2, 20, 20, 3)) | |
| with self.assertRaisesRegex( | |
| TypeError, "Expected an int or a tuple" | |
| ): | |
| kimage.extract_patches(image, size) | |
| size = (3, 3, 3, 3) # Invalid size, too many dimensions | |
| with self.assertRaisesRegex( | |
| ValueError, "Expected a tuple of length 2 or 3" | |
| ): | |
| kimage.extract_patches(image, size) |
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #21980 +/- ##
========================================
Coverage 82.69% 82.70%
========================================
Files 588 588
Lines 61369 61570 +201
Branches 9612 9641 +29
========================================
+ Hits 50749 50919 +170
- Misses 8135 8156 +21
- Partials 2485 2495 +10
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
hertschuh
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for the refactor!
One small comment:
keras/src/ops/image.py
Outdated
| For 2D patches: 3D `(H, W, C)` or 4D `(N, H, W, C)`. | ||
| For 3D patches: 4D `(D, H, W, C)` or 5D `(N, D, H, W, C)`. | ||
| size: Patch size as int or tuple. | ||
| Length 2 tuple `(patch_height, patch_width)` for 2D patches. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Small tweak: mention that if the size is an int it will do 2D patches:
size: Patch size as int or tuple.
Length 2 tuple `(patch_height, patch_width)` or int for 2D patches.
|
Hey just committed the changes thanks for the help! |
hertschuh
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for this.
Would you be able to do this as a next step in a separate PR?
This is to fully merge both ops in a single one and have only one op class:
- rename the current
_extract_patchesto_extract_patches_2d - create a new
_extract_patchesfunction that does
if not isinstance(size, int) and len(size) == 3:
return _extract_patches_3d(...)
else:
return _extract_patches_2d(...)Note that this will make ExtractPatches work with 3D too.
- revert
extract_patchesso that it only does:
if any_symbolic_tensors((images,)):
return ExtractPatches(
size=size,
strides=strides,
dilation_rate=dilation_rate,
padding=padding,
data_format=data_format,
).symbolic_call(images)
return _extract_patches(
images, size, strides, dilation_rate, padding, data_format=data_format
)- Remove the
ExtractPatches3Dclass
That would be great. Thanks!
Summary
Unified extract_patches to handle both 2D and 3D patch extraction based on the size argument, similar to how keras.ops.conv handles different dimensions.
Changes
size=(h, w) -> 2D patch extraction. This already existed.
size=(d, h, w) -> 3D patch extraction
Why
As suggested in #21624, having a unified API is more consistent with other Keras ops like conv.
Tests
Updated test_extract_patches_invalid_size to reflect new valid 3D sizes.
Added test_extract_patches_unified_3d for the new functionality.
All existing extract_patches tests pass.
Relates to #21624